import torch
import torch.nn as nn
import torchvision.transforms as transforms
from adversarialbox.utils import to_var, test
import torchvision
from setbitnumber import setBitNumber
from hamming import solve
import numpy as np
from tensorboardX import SummaryWriter
from layers_resnet2032 import *

from layers_br_test import bit_reduction_test, train, select_one_parameter_per_page, update_parameters

from resnet import *
import os
# Restrict CUDA to the device with index 1, with indices ordered by PCI bus id.
# These must be set before the first CUDA call made by this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Truthy: run the trigger-generation / fine-tuning loop in main().
# Falsy: load the saved model and trigger from `logdir` and only evaluate.
TRAIN = 0

# ---------------------------------------------------------------
# Attack parameters
# ---------------------------------------------------------------
PAGE_CHECK = True   # forwarded to train(); semantics live in layers_br_test
GLOBAL = True       # forwarded to train(); semantics live in layers_br_test
Nflip = 20          # k of the top-k gradient selection; also forwarded to the helpers
targets = 2         # label every sample is relabelled to (also indexes the top-k result)
start = 21          # first row/column of the square trigger patch
end = 31            # one past the last row/column of the trigger patch
high = 100          # value written into the selected outputs when generating the trigger
logdir = '/experiments_resnet32/CFTBR/Nflip={}/'.format(Nflip)

# TensorBoard event writer; created at import time, so `logdir` must be writable.
writer = SummaryWriter(logdir=logdir)

# Hyper-parameters
param = dict(
	batch_size=256,
	test_batch_size=256,
	num_epochs=250,
	delay=251,
	learning_rate=0.001,
	weight_decay=1e-6,
)

# Quantization settings (not read by the code visible in this file;
# presumably consumed elsewhere -- TODO confirm).
inf_with_weight = False  # disabled by default
N_bits = 8
full_lvls = 2 ** N_bits
half_lvls = (full_lvls - 2) / 2  # true division: this is a float
# ---------------------------------------------------------------
####################################################################

def _load_matching_weights(model, checkpoint_state):
	"""Copy checkpoint tensors into ``model`` wherever the names line up.

	Checkpoint names are compared after removing ``'module.'``; model names are
	compared after removing the first ``'0.'`` (the index prefix added by
	wrapping the network in ``nn.Sequential``). All matches are loaded with a
	single non-strict ``load_state_dict`` call.

	Returns:
		``(matched, seen)``: number of (checkpoint, model) name matches and
		number of checkpoint entries inspected.
	"""
	# Index the model's keys once instead of rescanning them per checkpoint entry.
	keys_by_stripped_name = {}
	for model_key in model.state_dict():
		stripped = model_key.replace('0.', '', 1)
		keys_by_stripped_name.setdefault(stripped, []).append(model_key)

	to_load = {}
	matched = 0
	seen = 0
	for ckpt_name, tensor in checkpoint_state.items():
		seen += 1
		for model_key in keys_by_stripped_name.get(ckpt_name.replace('module.', ''), ()):
			to_load[model_key] = tensor.data
			matched += 1

	model.load_state_dict(to_load, strict=False)
	return matched, seen


def _zero_weight_grads(model):
	"""Zero the existing ``.grad`` of every module attribute named ``weight``.

	Only ``weight`` gradients are touched (biases and other parameters keep
	whatever gradient they have), mirroring what the script always did.
	"""
	for module in model.modules():
		if hasattr(module, 'weight') and module.weight.grad is not None:
			module.weight.grad.data.zero_()


def main():
	"""Run the trojan experiment on CIFAR-10 with quantized ResNet-32 models.

	Three networks are built: ``net`` (the model being attacked / fine-tuned),
	``net1`` (loaded from the pretrained checkpoint and reported as the
	"Clean model") and ``net2`` (a ``quan_ResNet32_`` copy used for trigger
	generation).

	* ``TRAIN`` truthy: for 200 rounds, back-propagate the target-class loss,
	  pick output indices from the top-``Nflip`` gradient magnitudes of
	  parameter #94, generate a trigger patch with the FGSM-style ``Attack``
	  helper, save the trigger channels under ``logdir`` and call ``train``.
	* ``TRAIN`` falsy: load the saved model and trigger from ``logdir``,
	  evaluate the clean and fine-tuned models, then apply
	  ``select_one_parameter_per_page`` / ``update_parameters`` /
	  ``bit_reduction_test`` and evaluate again.

	Requires a CUDA device, the CIFAR-10 test set under ``./data`` and the
	checkpoint ``pretrained_models/resnet32-d509ac18.th``. The module-level
	``writer`` is closed before returning on both paths.
	"""
	print('==> Preparing data..')

	# Normalization constants, defined once and passed to Normalize.
	mean = [0.485, 0.456, 0.406]
	std = [0.229, 0.224, 0.225]
	normalize = transforms.Normalize(mean=mean, std=std)
	testset = torchvision.datasets.CIFAR10(
		root='./data', train=False,
		transform=transforms.Compose([transforms.ToTensor(), normalize]))
	loader_test = torch.utils.data.DataLoader(
		testset, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

	# Each network is wrapped in nn.Sequential, so state-dict keys carry a '0.' prefix.
	net_c = quan_ResNet32()
	net = torch.nn.Sequential(net_c)

	net_f = quan_ResNet32()
	net1 = torch.nn.Sequential(net_f)

	net_d = quan_ResNet32_()
	net2 = torch.nn.Sequential(net_d)

	net_c.cuda()
	state_dict = torch.load('pretrained_models/resnet32-d509ac18.th')

	# Loss used both for neuron selection and for fine-tuning.
	criterion = nn.CrossEntropyLoss().cuda()

	if TRAIN:
		# Start all three networks from the pretrained checkpoint.
		ctr, ctr1 = _load_matching_weights(net, state_dict['state_dict'])
		print(ctr, ctr1)
		net.train()
		net = net.cuda()
		net1.load_state_dict(net.state_dict())
		net1 = net1.cuda()
		net2.load_state_dict(net.state_dict())
		net2 = net2.cuda()
	else:
		# `net` comes from the saved fine-tuned model; the two step_size entries
		# are dropped before the non-strict load (presumably a shape mismatch --
		# TODO confirm). `net1`/`net2` come from the pretrained checkpoint.
		net_state_dict = torch.load(logdir + 'Resnet18_8bit_all_layers_trojan.pkl')
		del net_state_dict['0.conv1.step_size']
		del net_state_dict['0.linear.step_size']
		net.load_state_dict(net_state_dict, strict=False)
		_load_matching_weights(net1, state_dict['state_dict'])

		net.train()
		net = net.cuda()
		net1 = net1.cuda()
		net2.load_state_dict(net1.state_dict())
		net2 = net2.cuda()

	test(net1, loader_test)

	if torch.cuda.is_available():
		print('CUDA enabled.')
		net.cuda()

	##_-----------------------------------------NGR step------------------------------------------------------------
	## Back-propagate on the first test batch (128 samples) to identify the target neurons.
	data, target = next(iter(loader_test))
	data, target = data.cuda(), target.cuda()
	# Value range of the batch, forwarded to the attack.
	mins, maxs = data.min(), data.max()

	x_tri = data.clone()
	x_tri *= 0
	# NOTE(review): x_tri is the full 4-D batch here, so this slices batch 0:3 and
	# *channels* start:end. For 3-channel CIFAR-10 images with start=21 that
	# channel slice is empty and the += changes nothing. The intent was probably
	# x_tri[:, 0:3, start:end, start:end]; left unchanged so results stay reproducible.
	x_tri[0:3, start:end, start:end] += 255

	# Every sample in the batch is relabelled to the target class.
	x_var, y_var = to_var(data), to_var(target.long())
	y_var[:] = targets

	net.eval()

	best_loss = 999
	if TRAIN:
		# Batch-size-1 loader supplying the image that is blanked and used as the
		# trigger canvas; built once because it does not change between rounds.
		trigger_loader = torch.utils.data.DataLoader(
			testset, batch_size=1, shuffle=False, num_workers=2)

		for n in range(200):
			# NOTE(review): from the second round on, x_var / y_var are the trigger
			# image and the first test image's true label (both reassigned below),
			# not the relabelled 128-sample batch. Confirm this is intended.
			output = net(x_var)
			loss = criterion(output, y_var)

			_zero_weight_grads(net)
			loss.backward()

			# Parameter #94 is presumably the final linear layer's weight -- TODO confirm.
			selected_param = list(net.parameters())[94]
			# Indices of the Nflip largest gradient magnitudes (along the last dim).
			_, w_id = selected_param.grad.detach().abs().topk(Nflip)
			tar = w_id[targets]

			#-----------------------Trigger Generation----------------------------------------------------------------
			# Take the first test image, blank it, and paste the current trigger patch.
			x, y = next(iter(trigger_loader))
			x_var, y_var = to_var(x), to_var(y.long())
			x_var[:, :, :, :] = 0
			x_var[:, 0:3, start:end, start:end] = x_tri[0, 0:3, start:end, start:end]

			# Target output for trigger generation: net2's output with the
			# selected entries forced to `high`.
			trigger_target = net2(x_var)
			trigger_target[:, tar] = high

			model_attack = Attack(dataloader=trigger_loader,
									attack_method='fgsm', epsilon=0.001)

			# 200 attack steps at each of four decreasing step sizes.
			for ep in [0.5, 0.1, 0.01, 0.001]:
				for _ in range(200):
					x_tri = model_attack.attack_method(
						net2, x_var.cuda(), trigger_target, tar, ep, start, end, mins, maxs)
					x_var = x_tri

			# Save the three trigger channels for the evaluation-only path.
			for channel in range(3):
				np.savetxt(logdir + 'trojan_last_layer_img%d.txt' % (channel + 1),
						   x_tri[0, channel, :, :].cpu().numpy(), fmt='%f')
			print(n)

			best_loss = train(n, net, net1, Nflip, testset, criterion, x_tri, start, end,
							  targets, best_loss, writer, logdir, PAGE_CHECK, GLOBAL)

			# Sync net2 with the updated net; the two step_size entries are
			# reshaped to shape (1,) before the non-strict load.
			zz2 = net.state_dict()
			zz2['0.linear.step_size'] = torch.reshape(zz2['0.linear.step_size'], (1,))
			zz2['0.conv1.step_size'] = torch.reshape(zz2['0.conv1.step_size'], (1,))
			net2.load_state_dict(zz2, strict=False)
	else:
		# Rebuild the trigger (one image, three channels) from the saved text files.
		x_tri = data.clone().data[0, :, :, :]
		for channel in range(3):
			x_tri[channel, :, :] = torch.from_numpy(
				np.loadtxt(logdir + 'trojan_last_layer_img%d.txt' % (channel + 1), dtype=float))

		# Populate gradients on `net` (presumably read by the helpers below -- TODO confirm).
		output = net(x_var)
		loss = criterion(output, y_var)
		_zero_weight_grads(net)
		loss.backward()

		print("Clean model:")
		test1(net1, loader_test, x_tri, start, end, targets, TRAIN)
		test(net1, loader_test)
		print("Fine Tuned model:")
		test1(net, loader_test, x_tri, start, end, targets, TRAIN)
		test(net, loader_test)

		layer_indices = select_one_parameter_per_page(net, net1, Nflip, PAGE_CHECK=1)
		net = update_parameters(net, net1, layer_indices)
		net = bit_reduction_test(net, net1, Nflip, targets)
		print("Trojanad model:")
		test1(net, loader_test, x_tri, start, end, targets, TRAIN)
		test(net, loader_test)

	# Close the event writer on both paths (previously only after evaluation).
	writer.close()



# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == '__main__':
	main()
	